{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "476e2f14",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# 自然语言推断与数据集\n",
    ":label:`sec_natural-language-inference-and-dataset`\n",
    "\n",
    "在 :numref:`sec_sentiment`中，我们讨论了情感分析问题。这个任务的目的是将单个文本序列分类到预定义的类别中，例如一组情感极性中。然而，当需要决定一个句子是否可以从另一个句子推断出来，或者需要通过识别语义等价的句子来消除句子间冗余时，知道如何对一个文本序列进行分类是不够的。相反，我们需要能够对成对的文本序列进行推断。\n",
    "\n",
    "## 自然语言推断\n",
    "\n",
    "*自然语言推断*（natural language inference）主要研究\n",
    "*假设*（hypothesis）是否可以从*前提*（premise）中推断出来，\n",
    "其中两者都是文本序列。\n",
    "换言之，自然语言推断决定了一对文本序列之间的逻辑关系。这类关系通常分为三种类型：\n",
    "\n",
    "* *蕴涵*（entailment）：假设可以从前提中推断出来。\n",
    "* *矛盾*（contradiction）：假设的否定可以从前提中推断出来。\n",
    "* *中性*（neutral）：所有其他情况。\n",
    "\n",
    "自然语言推断也被称为识别文本蕴涵任务。\n",
    "例如，下面的一个文本对将被贴上“蕴涵”的标签，因为假设中的“表白”可以从前提中的“拥抱”中推断出来。\n",
    "\n",
    ">前提：两个女人拥抱在一起。\n",
    "\n",
    ">假设：两个女人在示爱。\n",
    "\n",
    "下面是一个“矛盾”的例子，因为“运行编码示例”表示“不睡觉”，而不是“睡觉”。\n",
    "\n",
    ">前提：一名男子正在运行Dive Into Deep Learning的编码示例。\n",
    "\n",
    ">假设：该男子正在睡觉。\n",
    "\n",
    "第三个例子显示了一种“中性”关系，因为“正在为我们表演”这一事实无法推断出“出名”或“不出名”。\n",
    "\n",
    ">前提：音乐家们正在为我们表演。\n",
    "\n",
    ">假设：音乐家很有名。\n",
    "\n",
    "自然语言推断一直是理解自然语言的中心话题。它有着广泛的应用，从信息检索到开放领域的问答。为了研究这个问题，我们将首先研究一个流行的自然语言推断基准数据集。\n",
    "\n",
    "## 斯坦福自然语言推断（SNLI）数据集\n",
    "\n",
    "[**斯坦福自然语言推断语料库（Stanford Natural Language Inference，SNLI）**]是由500000多个带标签的英语句子对组成的集合 :cite:`Bowman.Angeli.Potts.ea.2015`。我们在路径`../data/snli_1.0`中下载并存储提取的SNLI数据集。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "70033cdc",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-07T17:00:51.725473Z",
     "iopub.status.busy": "2022-12-07T17:00:51.725143Z",
     "iopub.status.idle": "2022-12-07T17:00:58.509145Z",
     "shell.execute_reply": "2022-12-07T17:00:58.508303Z"
    },
    "origin_pos": 2,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import torch\n",
    "from torch import nn\n",
    "from d2l import torch as d2l\n",
    "\n",
    "#@save\n",
    "d2l.DATA_HUB['SNLI'] = (\n",
    "    'https://nlp.stanford.edu/projects/snli/snli_1.0.zip',\n",
    "    '9fcde07509c7e87ec61c640c1b2753d9041758e4')\n",
    "\n",
    "data_dir = d2l.download_extract('SNLI')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eead0f32",
   "metadata": {
    "origin_pos": 4
   },
   "source": [
    "### [**读取数据集**]\n",
    "\n",
    "原始的SNLI数据集包含的信息比我们在实验中真正需要的信息丰富得多。因此，我们定义函数`read_snli`以仅提取数据集的一部分，然后返回前提、假设及其标签的列表。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7a4adf0f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-07T17:00:58.513149Z",
     "iopub.status.busy": "2022-12-07T17:00:58.512779Z",
     "iopub.status.idle": "2022-12-07T17:00:58.520419Z",
     "shell.execute_reply": "2022-12-07T17:00:58.519708Z"
    },
    "origin_pos": 5,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "#@save\n",
    "def read_snli(data_dir, is_train):\n",
    "    \"\"\"将SNLI数据集解析为前提、假设和标签\"\"\"\n",
    "    def extract_text(s):\n",
    "        # 删除我们不会使用的信息\n",
    "        s = re.sub('\\\\(', '', s)\n",
    "        s = re.sub('\\\\)', '', s)\n",
    "        # 用一个空格替换两个或多个连续的空格\n",
    "        s = re.sub('\\\\s{2,}', ' ', s)\n",
    "        return s.strip()\n",
    "    label_set = {'entailment': 0, 'contradiction': 1, 'neutral': 2}\n",
    "    file_name = os.path.join(data_dir, 'snli_1.0_train.txt'\n",
    "                             if is_train else 'snli_1.0_test.txt')\n",
    "    with open(file_name, 'r') as f:\n",
    "        rows = [row.split('\\t') for row in f.readlines()[1:]]\n",
    "    premises = [extract_text(row[1]) for row in rows if row[0] in label_set]\n",
    "    hypotheses = [extract_text(row[2]) for row in rows if row[0] \\\n",
    "                in label_set]\n",
    "    labels = [label_set[row[0]] for row in rows if row[0] in label_set]\n",
    "    return premises, hypotheses, labels"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7648e401",
   "metadata": {
    "origin_pos": 6
   },
   "source": [
    "现在让我们[**打印前3对**]前提和假设，以及它们的标签（“0”“1”和“2”分别对应于“蕴涵”“矛盾”和“中性”）。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3d3e7ff4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-07T17:00:58.523326Z",
     "iopub.status.busy": "2022-12-07T17:00:58.523060Z",
     "iopub.status.idle": "2022-12-07T17:01:10.535895Z",
     "shell.execute_reply": "2022-12-07T17:01:10.535070Z"
    },
    "origin_pos": 7,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "前提： A person on a horse jumps over a broken down airplane .\n",
      "假设： A person is training his horse for a competition .\n",
      "标签： 2\n",
      "前提： A person on a horse jumps over a broken down airplane .\n",
      "假设： A person is at a diner , ordering an omelette .\n",
      "标签： 1\n",
      "前提： A person on a horse jumps over a broken down airplane .\n",
      "假设： A person is outdoors , on a horse .\n",
      "标签： 0\n"
     ]
    }
   ],
   "source": [
    "train_data = read_snli(data_dir, is_train=True)\n",
    "for x0, x1, y in zip(train_data[0][:3], train_data[1][:3], train_data[2][:3]):\n",
    "    print('前提：', x0)\n",
    "    print('假设：', x1)\n",
    "    print('标签：', y)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6121717f",
   "metadata": {
    "origin_pos": 8
   },
   "source": [
    "训练集约有550000对，测试集约有10000对。下面显示了训练集和测试集中的三个[**标签“蕴涵”“矛盾”和“中性”是平衡的**]。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6b95f69b",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-07T17:01:10.541102Z",
     "iopub.status.busy": "2022-12-07T17:01:10.540564Z",
     "iopub.status.idle": "2022-12-07T17:01:10.854334Z",
     "shell.execute_reply": "2022-12-07T17:01:10.853505Z"
    },
    "origin_pos": 9,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[183416, 183187, 182764]\n",
      "[3368, 3237, 3219]\n"
     ]
    }
   ],
   "source": [
    "test_data = read_snli(data_dir, is_train=False)\n",
    "for data in [train_data, test_data]:\n",
    "    print([[row for row in data[2]].count(i) for i in range(3)])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "64afc0f3",
   "metadata": {
    "origin_pos": 10
   },
   "source": [
    "### [**定义用于加载数据集的类**]\n",
    "\n",
    "下面我们来定义一个用于加载SNLI数据集的类。类构造函数中的变量`num_steps`指定文本序列的长度，使得每个小批量序列将具有相同的形状。换句话说，在较长序列中的前`num_steps`个标记之后的标记被截断，而特殊标记“&lt;pad&gt;”将被附加到较短的序列后，直到它们的长度变为`num_steps`。通过实现`__getitem__`功能，我们可以任意访问带有索引`idx`的前提、假设和标签。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6c33e562",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-07T17:01:10.857887Z",
     "iopub.status.busy": "2022-12-07T17:01:10.857321Z",
     "iopub.status.idle": "2022-12-07T17:01:10.865250Z",
     "shell.execute_reply": "2022-12-07T17:01:10.864549Z"
    },
    "origin_pos": 12,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "#@save\n",
    "class SNLIDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"用于加载SNLI数据集的自定义数据集\"\"\"\n",
    "    def __init__(self, dataset, num_steps, vocab=None):\n",
    "        self.num_steps = num_steps\n",
    "        all_premise_tokens = d2l.tokenize(dataset[0])\n",
    "        all_hypothesis_tokens = d2l.tokenize(dataset[1])\n",
    "        if vocab is None:\n",
    "            self.vocab = d2l.Vocab(all_premise_tokens + \\\n",
    "                all_hypothesis_tokens, min_freq=5, reserved_tokens=['<pad>'])\n",
    "        else:\n",
    "            self.vocab = vocab\n",
    "        self.premises = self._pad(all_premise_tokens)\n",
    "        self.hypotheses = self._pad(all_hypothesis_tokens)\n",
    "        self.labels = torch.tensor(dataset[2])\n",
    "        print('read ' + str(len(self.premises)) + ' examples')\n",
    "\n",
    "    def _pad(self, lines):\n",
    "        return torch.tensor([d2l.truncate_pad(\n",
    "            self.vocab[line], self.num_steps, self.vocab['<pad>'])\n",
    "                         for line in lines])\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return (self.premises[idx], self.hypotheses[idx]), self.labels[idx]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.premises)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5fefc51",
   "metadata": {
    "origin_pos": 14
   },
   "source": [
    "### [**整合代码**]\n",
    "\n",
    "现在，我们可以调用`read_snli`函数和`SNLIDataset`类来下载SNLI数据集，并返回训练集和测试集的`DataLoader`实例，以及训练集的词表。值得注意的是，我们必须使用从训练集构造的词表作为测试集的词表。因此，在训练集中训练的模型将不知道来自测试集的任何新词元。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7c18ffd2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-07T17:01:10.868549Z",
     "iopub.status.busy": "2022-12-07T17:01:10.867902Z",
     "iopub.status.idle": "2022-12-07T17:01:10.873371Z",
     "shell.execute_reply": "2022-12-07T17:01:10.872640Z"
    },
    "origin_pos": 16,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "#@save\n",
    "def load_data_snli(batch_size, num_steps=50):\n",
    "    \"\"\"下载SNLI数据集并返回数据迭代器和词表\"\"\"\n",
    "    num_workers = d2l.get_dataloader_workers()\n",
    "    data_dir = d2l.download_extract('SNLI')\n",
    "    train_data = read_snli(data_dir, True)\n",
    "    test_data = read_snli(data_dir, False)\n",
    "    train_set = SNLIDataset(train_data, num_steps)\n",
    "    test_set = SNLIDataset(test_data, num_steps, train_set.vocab)\n",
    "    train_iter = torch.utils.data.DataLoader(train_set, batch_size,\n",
    "                                             shuffle=True,\n",
    "                                             num_workers=num_workers)\n",
    "    test_iter = torch.utils.data.DataLoader(test_set, batch_size,\n",
    "                                            shuffle=False,\n",
    "                                            num_workers=num_workers)\n",
    "    return train_iter, test_iter, train_set.vocab"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5f3683b",
   "metadata": {
    "origin_pos": 18
   },
   "source": [
    "在这里，我们将批量大小设置为128时，将序列长度设置为50，并调用`load_data_snli`函数来获取数据迭代器和词表。然后我们打印词表大小。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "80577e00",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-07T17:01:10.876575Z",
     "iopub.status.busy": "2022-12-07T17:01:10.875993Z",
     "iopub.status.idle": "2022-12-07T17:01:48.412103Z",
     "shell.execute_reply": "2022-12-07T17:01:48.411314Z"
    },
    "origin_pos": 19,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "read 549367 examples\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "read 9824 examples\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "18678"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_iter, test_iter, vocab = load_data_snli(128, 50)\n",
    "len(vocab)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3b25fa4",
   "metadata": {
    "origin_pos": 20
   },
   "source": [
    "现在我们打印第一个小批量的形状。与情感分析相反，我们有分别代表前提和假设的两个输入`X[0]`和`X[1]`。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a7724e3f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-07T17:01:48.415607Z",
     "iopub.status.busy": "2022-12-07T17:01:48.415044Z",
     "iopub.status.idle": "2022-12-07T17:01:48.722882Z",
     "shell.execute_reply": "2022-12-07T17:01:48.721898Z"
    },
    "origin_pos": 21,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([128, 50])\n",
      "torch.Size([128, 50])\n",
      "torch.Size([128])\n"
     ]
    }
   ],
   "source": [
    "for X, Y in train_iter:\n",
    "    print(X[0].shape)\n",
    "    print(X[1].shape)\n",
    "    print(Y.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0c38ba5",
   "metadata": {
    "origin_pos": 22
   },
   "source": [
    "## 小结\n",
    "\n",
    "* 自然语言推断研究“假设”是否可以从“前提”推断出来，其中两者都是文本序列。\n",
    "* 在自然语言推断中，前提和假设之间的关系包括蕴涵关系、矛盾关系和中性关系。\n",
    "* 斯坦福自然语言推断（SNLI）语料库是一个比较流行的自然语言推断基准数据集。\n",
    "\n",
    "## 练习\n",
    "\n",
    "1. 机器翻译长期以来一直是基于翻译输出和翻译真实值之间的表面$n$元语法匹配来进行评估的。可以设计一种用自然语言推断来评价机器翻译结果的方法吗？\n",
    "1. 我们如何更改超参数以减小词表大小？\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37719cd1",
   "metadata": {
    "origin_pos": 24,
    "tab": [
     "pytorch"
    ]
   },
   "source": [
    "[Discussions](https://discuss.d2l.ai/t/5722)\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}